#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Nov  9 12:02:36 2020

@author: agyeman
"""

import numpy as np 
from casadi import *
from Parameters_1D import *
from Model_1D import *
from ZoneMPC_MS import *
from matplotlib import pyplot as plt

import time


# Spatial grid and timing configuration provided by the project helpers.
(totalDepth, axialNodes, axialDistance,
 totalNodes, numOfActuators, interPoints) = spatialVariables_1D()
(samplingTime, samplingTimeInternal,
 internalTimeSteps, totTimeSteps) = temporalVariable()  # time parameters

# Soil parameter set returned by loamySoil(); not referenced again in this script.
soilPars = loamySoil()

# Symbolic polar model and the CasADi integrator.
# States: one head value per node; parameters: actuator inputs, then the
# crop coefficient, then the evaporation value (in that order, which the
# closed-loop loop relies on when it builds `p`).
headSym = SX.sym('h', totalNodes)
inpSym = SX.sym('u', numOfActuators)
cropSym = SX.sym('c', 1)
evapSym = SX.sym('e', 1)

polarModelSym = RichardsPolar_1D(headSym, inpSym, cropSym, evapSym)
paramSym = vertcat(inpSym, cropSym, evapSym)
ODE = dict(x=headSym, p=paramSym, ode=polarModelSym)
# Each call of I advances the state by samplingTimeInternal.
opts = dict(tf=samplingTimeInternal, regularity_check=True)
I = integrator('I', 'cvodes', ODE, opts)

# State and the arrays that store the results.

# Initial head: the same value at every node.
x0 = -0.96*np.ones(totalNodes)

# Input actually applied at each closed-loop step (one row per step).
uArray = np.zeros((totTimeSteps, numOfActuators))

# Closed-loop trajectory: row k is the state at step k (row 0 is x0).
xArray_CL = np.zeros((totTimeSteps+1, totalNodes))
xArray_CL[0,:] = x0

# One-step-ahead states predicted by the MPC. Rows are only written on steps
# where the MPC is solved; every other row stays zero.
xArray_OL = np.zeros((totTimeSteps+1, totalNodes))
xArray_OL[0,:] = x0

# NOTE: (totTimeSteps, totalNodes) arrays used to be allocated here under the
# names `sllb`/`slub`. They were dead: both names are rebound to the 1-D slack
# bounds further down, before any use. If a slack *history* is ever needed,
# store it under different names so the bounds are not clobbered.

# Crop coefficient and evaporation input (presumably reference
# evapotranspiration; units not visible here). Both are held constant over
# the whole simulation.
cropCoeff = 0.88*np.ones(totTimeSteps)
refEvap = 3.102e-08*np.ones(totTimeSteps)

# Time of each closed-loop sample. This equals k * samplingTimeInternal,
# because each step advances the integrator by samplingTimeInternal.
times = samplingTimeInternal*totTimeSteps*np.linspace(0,1,totTimeSteps+1)

# Tuning matrices (all passed to ZoneMPC_multiShooting in the closed loop).

# Zero weight on tracking xs. Presumably this leaves only the zone penalties
# active inside ZoneMPC_multiShooting -- confirm there.
Q = 0*np.eye(totalNodes)

# Penalty weights for violating the upper/lower zone bounds.
QUpper = 10000*np.eye(totalNodes)
QLower = 100000*np.eye(totalNodes)
# NOTE(review): these scalar variants are not referenced anywhere in this script.
QUpper_new = 10000
QLower_new = 100000

# Target zone, one entry per zone node.
# Earlier upper-zone choice, kept for reference:
#ZoneUpper=np.array([-0.15,-0.45,-0.60,-0.62])
ZoneUpper = np.array([-0.597, -0.609, -0.601, -0.14])
ZoneLower = np.array([-0.81, -0.81, -0.81, -0.63])

# Bounds on the zone slack variables: non-negative and unbounded above.
# They are float arrays shaped like the zone arrays, so they stay consistent
# if the zone definition changes.
sllb = np.zeros_like(ZoneUpper)
slub = np.full_like(ZoneUpper, np.inf)

# Input weight.
R = 100000*np.eye(numOfActuators)

# Reference points and the input and state constraints.
us = -1.0e-07*np.ones(numOfActuators)
xs = np.array([-0.676, -0.681, -0.691, -0.705])
# xs and us are negative, so the larger factor gives the lower bound.
xlb = 8*xs[:totalNodes]
xub = 0.02*xs[:totalNodes]
ulb = 200*us
uub = 0.001*us
uub = 0.001*us

def inputConstraint(i, ub, lb, period=8640, window=35):
    """Return the input bounds that apply at closed-loop step ``i``.

    Inputs are allowed only during the first ``window + 1`` steps of every
    ``period``-step cycle, i.e. when ``i % period <= window``. Outside that
    window both bounds are zero, which forces a zero input.

    Parameters
    ----------
    i : int
        Closed-loop step index.
    ub, lb : array_like
        Upper and lower input bounds used inside the window.
    period : int, optional
        Length of the repeating cycle in steps. Defaults to 8640, the value
        that was previously hard-coded.
    window : int, optional
        Last offset within each cycle (inclusive) at which ``ub``/``lb``
        apply. Defaults to 35.

    Returns
    -------
    tuple
        ``(UB, LB)``: the given ``(ub, lb)`` objects inside the window,
        otherwise float zero arrays shaped like ``ub`` and ``lb``.
    """
    if i % period <= window:
        return ub, lb
    # Size the zeros from the bounds themselves rather than from the global
    # numOfActuators, so the helper is self-contained.
    return np.zeros(np.shape(ub)), np.zeros(np.shape(lb))


# Closed-loop simulation. The MPC is solved only while inputs are allowed:
# the first irrigationWindow + 1 steps of every irrigationPeriod-step cycle
# (keep these in sync with inputConstraint). At every other step the input
# is zero.
irrigationPeriod = 8640
irrigationWindow = 35
# One zero per actuator. The old `np.array([0])` only matched a single
# actuator and gave the integrator a parameter vector of the wrong length
# whenever numOfActuators > 1.
zeroInput = np.zeros(numOfActuators)

for i in range(totTimeSteps):
    print(i)
    if i % irrigationPeriod <= irrigationWindow:
        ub_ctr, lb_ctr = inputConstraint(i, uub, ulb)
        w, lbw, ubw, G, J, Guess, lbg, ubg = ZoneMPC_multiShooting(
            xArray_CL[i,:], xs, us, I, Q, R, QUpper, QLower, xub, xlb,
            ub_ctr, lb_ctr, ZoneLower, ZoneUpper, sllb, slub,
            cropCoeff[i], refEvap[i])

        x_ol, u, sL, sU = createAndSolveNLP(Guess, w, lbw, ubw, G, J, lbg, ubg)
        print(u)

        # Receding horizon: keep the first predicted state and apply only the
        # first input move. This assumes x_ol/u are indexed by horizon step
        # -- confirm against createAndSolveNLP.
        xArray_OL[i+1,:] = x_ol[1,:]
        uApplied = u[0]
    else:
        uApplied = zeroInput

    uArray[i,:] = uApplied
    # Advance the plant by one integrator interval with the applied input.
    # Parameter order must match the `p` vector built for the integrator:
    # inputs, crop coefficient, evaporation.
    r = I(x0=xArray_CL[i,:], p=vertcat(uApplied, cropCoeff[i], refEvap[i]))
    xArray_CL[i+1,:] = r['xf'].full().ravel()
        


# Plots

# Zone bounds of node 0, taken from the zone definition instead of repeating
# the literals. The curves have one point per closed-loop sample, the same
# length as the trajectory plotted against them.
nSamples = xArray_CL.shape[0]
rootzone_upper = ZoneUpper[0]*np.ones(nSamples)
rootzone_lower = ZoneLower[0]*np.ones(nSamples)

plt.figure()
plt.plot(rootzone_lower, 'r', label='zone bounds (node 0)')
plt.plot(rootzone_upper, 'r')
plt.plot(xArray_CL[:,0], '--b', label='closed-loop head (node 0)')
plt.xlabel('closed-loop step')
plt.ylabel('head')
plt.legend()
# Without show() nothing is displayed when the file runs as a plain script.
plt.show()

